# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for cross_device_ops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import threading

from tensorflow.python.distribute import values as value_lib
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import collective_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nccl_ops
from tensorflow.python.platform import tf_logging as logging

# First instance key handed out for every device of a newly created collective
# group; CollectiveKeys.get_group_key() seeds each device's instance-key
# counter with this value.
INSTANCE_KEY_START_NUMBER = 100


def aggregate_gradients_using_nccl(replica_grads):
  """Sums gradients across replicas with an NCCL all-reduce.

  Args:
    replica_grads: a list with one entry per replica, each entry being a list
      of (gradient, variable) pairs. The k-th pairs of all replicas are
      all-reduced together.

  Returns:
    A list with one entry per replica; each entry is a tuple of
    (summed_gradient, variable) pairs in input order, where `variable` is the
    one that replica passed in.
  """
  per_position = []
  for pairs_across_replicas in zip(*replica_grads):
    summed = nccl_ops.all_sum([grad for grad, _ in pairs_across_replicas])
    per_position.append([
        (summed_grad, var)
        for summed_grad, (_, var) in zip(summed, pairs_across_replicas)
    ])
  # Transpose from per-gradient-position back to per-replica ordering.
  return list(zip(*per_position))


def aggregate_gradients_using_hierarchical_copy(avail_devices, replica_grads):
  """Aggregate gradients using hierarchical copies.

  The devices are split into two halves. For every gradient, each half is
  summed on a "main" device of that half, the two partial sums are added on
  the first group's main device, and the total is copied back to both main
  devices and from there to every device.

  Args:
    avail_devices: available GPU devices. NOTE(review): the code indexes this
      list with replica positions and splits it into two halves of
      `len(avail_devices) // 2`, so it presumably needs exactly one entry per
      replica and an even length — confirm with callers.
    replica_grads: List of lists of (gradient, variable) tuples. The outer list
      is over replicas. The inner list is over individual gradients.

  Returns:
    A list with one entry per replica; each entry is a tuple of
      (aggregated_gradient, variable) pairs, one per gradient. Each gradient
      has been summed across all replicas and copied onto that replica's
      device, and each variable is the one that replica passed in.
  """
  # This only works for DGX-1 type of machine topology
  # Device peer to peer matrix
  # DMA: 0 1 2 3 4 5 6 7
  # 0:   Y Y Y Y Y N N N
  # 1:   Y Y Y Y N Y N N
  # 2:   Y Y Y Y N N Y N
  # 3:   Y Y Y Y N N N Y
  # 4:   Y N N N Y Y Y Y
  # 5:   N Y N N Y Y Y Y
  # 6:   N N Y N Y Y Y Y
  # 7:   N N N Y Y Y Y Y
  agg_grads = []
  num_devices = len(avail_devices)
  # In the special case of DGX-1 machine topology, the two groups have equal
  # size.
  group_size = num_devices // 2
  for i, single_grads in enumerate(zip(*replica_grads)):
    # single_grads holds the i-th (gradient, variable) pair of every replica.
    # The main device rotates with i, presumably to spread the aggregation
    # work across devices. Group 1's main device is the counterpart of group
    # 0's main device in the other half.
    group_0_main_device = i % num_devices
    group_1_main_device = (group_0_main_device + group_size) % num_devices
    # "Group 0" is whichever half contains group_0_main_device.
    if group_0_main_device < group_size:
      group_0_begin = 0
      group_1_begin = group_size
    else:
      group_0_begin = group_size
      group_1_begin = 0

    # Aggregate the first group.
    group_0_device_grads = single_grads[group_0_begin:
                                        group_0_begin + group_size]
    with ops.device(avail_devices[group_0_main_device]):
      group_0_agg_grads, _ = aggregate_single_gradient_using_copy(
          group_0_device_grads, False, False)

    # Aggregate the second group.
    group_1_device_grads = single_grads[group_1_begin:
                                        group_1_begin + group_size]
    with ops.device(avail_devices[group_1_main_device]):
      group_1_agg_grads, _ = aggregate_single_gradient_using_copy(
          group_1_device_grads, False, False)

    # Aggregate between the groups. The two partial results are themselves
    # (gradient, variable) pairs, so they can be summed the same way.
    with ops.device(avail_devices[group_0_main_device]):
      (agg_total_grads, _), _ = aggregate_single_gradient_using_copy(
          [group_0_agg_grads, group_1_agg_grads], False, False)

    # Broadcast the result back into the root of each group.
    with ops.device(avail_devices[group_0_main_device]):
      group_0_agg_grads_bcast = array_ops.identity(agg_total_grads)
    with ops.device(avail_devices[group_1_main_device]):
      group_1_agg_grads_bcast = array_ops.identity(agg_total_grads)

    agg_grads_bcast = []
    for j in range(len(single_grads)):
      with ops.device(avail_devices[j]):
        # Broadcast the result back to each member in the group from the root.
        # Device j reads from the root that lives in the same half as j.
        if (group_0_main_device < group_size) == (j < group_size):
          src_device_grad = group_0_agg_grads_bcast
        else:
          src_device_grad = group_1_agg_grads_bcast
        agg_grads_bcast.append(array_ops.identity(src_device_grad))

    # Pair each replica's copy of the total with that replica's own variable.
    agg_grads.append(
        [(g, v) for g, (_, v) in zip(agg_grads_bcast, single_grads)])

  # Transpose from per-gradient to per-replica ordering.
  agg_grads = list(zip(*agg_grads))

  return agg_grads


def aggregate_single_gradient_using_copy(grad_and_vars, use_mean,
                                         check_inf_nan):
  """Calculate the sum or average gradient for a shared variable.

  Note that this function provides a synchronization point across all replicas.

  Args:
    grad_and_vars: A list or tuple of (gradient, variable) tuples. Each
      (gradient, variable) pair within the outer list represents the gradient
      of the variable calculated for a single replica, and the number of pairs
      equals the number of replicas.
    use_mean: if True, mean is taken, else sum of gradients is taken.
    check_inf_nan: check grads for nans and infs.

  Returns:
    The tuple ((aggregated_gradient, variable), has_nan_or_inf). The gradient
      has been summed (or averaged, if `use_mean`) across all replicas, and the
      variable is chosen from the first replica. `has_nan_or_inf` is a scalar
      boolean tensor that is True if any input gradient contains a NaN or Inf,
      or None when `check_inf_nan` is False.
  """
  grads = [g for g, _ in grad_and_vars]
  grad = math_ops.add_n(grads)

  if use_mean and len(grads) > 1:
    grad = math_ops.multiply(grad, 1.0 / len(grads))

  v = grad_and_vars[0][1]
  if check_inf_nan:
    # These are math ops; array_ops does not provide them. is_finite() packs
    # the list into one tensor, which presumes every replica's gradient has
    # the same shape (expected for one shared variable — confirm).
    has_nan_or_inf = math_ops.logical_not(
        math_ops.reduce_all(math_ops.is_finite(grads)))
    return (grad, v), has_nan_or_inf
  return (grad, v), None


# TODO(yuefengz): use random key starts to avoid reusing keys?
class CollectiveKeys(object):
  """Class that manages collective keys.

  We need to manage two different keys for collective:

  *Group key*: an integer key to identify the set of cooperative devices.
  Collective ops that work on the same set of devices must use the same group
  key.

  *Instance key*: an integer key to identify the set of same counterpart of
  tensors on different devices in a device group that need to be all-reduced.
  Instance keys are counted per device, starting at `INSTANCE_KEY_START_NUMBER`
  for every device of a newly created group.

  This class is thread safe.
  """

  def __init__(self, group_key_start=1):
    """Initializes the object.

    Args:
      group_key_start: the starting integer of group key.
    """
    # Next group key to hand out.
    self._group_key = group_key_start
    # Maps the sorted tuple of device strings to its group key. The tuple
    # itself is the dict key (not its hash), so two distinct device sets can
    # never collide onto the same group key.
    self._group_key_table = {}
    # Maps group key -> {device: next instance key for that device}.
    self._instance_key_table = {}
    self._lock = threading.Lock()

  def get_group_key(self, devices):
    """Returns a group key for the set of devices.

    The same set of devices, in any order, always maps to the same key.

    Args:
      devices: a list of canonical device strings in a collective group.

    Returns:
      int key uniquely identifying the set of device names.
    """
    key_id = tuple(sorted(devices))
    with self._lock:
      group_key = self._group_key_table.get(key_id)
      if group_key is None:
        group_key = self._group_key
        self._group_key += 1
        self._group_key_table[key_id] = group_key
        self._instance_key_table[group_key] = {
            device: INSTANCE_KEY_START_NUMBER for device in devices
        }
      return group_key

  def get_instance_key(self, group_key, device):
    """Returns a new instance key for use in defining a collective op.

    You should call this once per each collective op of a collective instance.

    Args:
      group_key: the group key returned by get_group_key(). You should not
        assign the group key yourself.
      device: a canonical device string. It should be the device this collective
        op is on.

    Returns:
      a new instance key.

    Raises:
      ValueError: when the group key is invalid or the device is not in the
      group.
    """
    with self._lock:
      group = self._instance_key_table.get(group_key, None)
      if group is None:
        raise ValueError('group {} not found'.format(group_key))
      if device not in group:
        raise ValueError('{} not in group {}'.format(device, group_key))
      v = group[device]
      group[device] += 1
      return v

  def __deepcopy__(self, memo):
    # distribute_coordinator deep-copies the strategy object, so
    # CollectiveKeys needs to support deep copy as well. The lock is not
    # copied: the copy gets a fresh one from __init__.
    copied = CollectiveKeys()
    memo[id(self)] = copied
    # Hold the lock so the copy sees a consistent snapshot of both tables.
    with self._lock:
      copied._group_key = self._group_key
      copied._group_key_table = copy.deepcopy(self._group_key_table, memo)
      copied._instance_key_table = copy.deepcopy(self._instance_key_table,
                                                 memo)
    return copied


class CollectiveReplicaLauncher(object):
  """Launch collectives on one replica.

  A launcher is bound to one `device` of a collective group identified by
  `group_key` and `group_size`. Every collective op it builds draws fresh
  instance keys for that device from the shared `CollectiveKeys`. Matching
  collectives on different replicas therefore only receive matching instance
  keys if every replica launches them in the same order.
  """

  # When True, batch_all_reduce never concatenates a pack into one tensor.
  # Instead it creates the pack's all-reduces under a single name scope so the
  # `ScopedAllocator` grappler optimizer can pack them.
  _use_scoped_allocator = True

  def __init__(self,
               group_key,
               group_size,
               collective_keys,
               device,
               executor=None):
    """Initializes the launcher.

    Args:
      group_key: the group key of the collective group, e.g. as returned by
        `CollectiveKeys.get_group_key()`.
      group_size: the group size passed to every collective op.
      collective_keys: the `CollectiveKeys` that allocates instance keys.
      device: a canonical device string; collective ops are placed on it.
      executor: optional executor used to launch collectives in eager mode.
        It must be async.

    Raises:
      ValueError: if `executor` is given and is not async.
    """
    if executor and not executor.is_async():
      raise ValueError('executor must be async')
    self._group_key = group_key
    self._group_size = group_size
    self._collective_keys = collective_keys
    self._device = device
    self._executor = executor

  def _executor_scope(self):
    """Returns the context to create collective ops under.

    Returns:
      The async executor's scope when executing eagerly, otherwise a no-op
      context.

    Raises:
      ValueError: when executing eagerly and no executor was supplied.
    """
    if context.executing_eagerly() and not self._executor:
      raise ValueError('collectives requires a async executor in eager mode')
    if context.executing_eagerly():
      return context.executor_scope(self._executor)
    return ops.NullContextmanager()

  def _control_input(self, control_input):
    """Returns a control-dependency context on `control_input`, or a no-op."""
    if control_input is not None:
      return ops.control_dependencies([control_input])
    return ops.NullContextmanager()

  def all_reduce(self,
                 input_tensor,
                 control_input=None,
                 communication_hint='AUTO',
                 timeout=0):
    """All-reduce a dense tensor.

    This can be called in eager mode if an async executor is supplied when
    creating the launcher.

    Args:
      input_tensor: a dense tensor. It must have the same shape on all replicas.
      control_input: if not None, add control edges between control_input and
        the all-reduce.
      communication_hint: string providing hint to runtime for choosing
        collective implementation.
      timeout: a float. The timeout in seconds.

    Returns:
      The reduced tensor.

    Raises:
      ValueError: if the group key or device is unknown to the
        `CollectiveKeys`, or when executing eagerly without an executor.
    """
    instance_key = self._collective_keys.get_instance_key(
        self._group_key, self._device)
    with self._executor_scope(), \
         ops.device(self._device), \
         self._control_input(control_input):
      return collective_ops.all_reduce(
          input_tensor,
          self._group_size,
          self._group_key,
          instance_key,
          communication_hint=communication_hint,
          timeout=timeout)

  def batch_all_reduce(self,
                       input_tensor_packs,
                       communication_hint='AUTO',
                       timeout=0):
    """Batch all-reduce dense tensors.

    This takes a list of batches of tensors. Using multiple batches has the
    benefit that it doesn't need to wait for all inputs to be ready to start the
    all-reduce.

    This can be called in eager mode if an async executor is supplied when
    creating the launcher.

    Args:
      input_tensor_packs: a list of lists of dense tensors.
      communication_hint: string providing hint to runtime for choosing
        collective implementation.
      timeout: a float. The timeout in seconds.

    Returns:
      A flat list of reduced tensors, in the order the inputs appear across
      the packs.

    Raises:
      ValueError: propagated from `all_reduce`.
    """
    # We don't batch with concat in eager. It's easy to get it wrong because
    # we need to avoid any numpy() calls on values produced by the async
    # executor. This effectively disables batching in eager, but it's unlikely
    # to all-reduce a large number of tensors in eager.
    batch_with_concat = (not self._use_scoped_allocator and
                         not context.executing_eagerly())
    outputs = []
    for pack in input_tensor_packs:
      # TODO(b/169168846): inserts a parallel all_gather to verify packings
      # are the same on each replica.
      if batch_with_concat:
        # Flatten and concatenate the pack into one 1-D tensor, all-reduce it
        # once, then split and reshape the result back into the originals.
        with ops.device(self._device):
          flat_tensors = [array_ops.reshape(t, [-1]) for t in pack]
          shapes = [array_ops.shape(t) for t in pack]
          # With NCCL, each all-reduce depends on the previous output so the
          # collectives are issued in a fixed order.
          if communication_hint == 'NCCL' and outputs:
            control_input = outputs[-1]
          else:
            control_input = None
          reduced = self.all_reduce(
              array_ops.concat(flat_tensors, axis=0), control_input,
              communication_hint, timeout)
          num_elements = [math_ops.reduce_prod(s) for s in shapes]
          flat_outputs = array_ops.split(reduced, num_elements, axis=0)
          for shape, flat_output in zip(shapes, flat_outputs):
            outputs.append(array_ops.reshape(flat_output, shape))
      else:
        # By placing all CollectiveReduce ops in a batch under single name
        # scope, we ensure they will be picked up by the `ScopedAllocator`
        # grappler optimizer and packed into a single all-reduce.
        with ops.name_scope('allreduce'):
          for input_tensor in pack:
            # With NCCL, chain each all-reduce after the previous output so
            # the collectives are issued in a fixed order.
            if communication_hint == 'NCCL' and outputs:
              control_input = outputs[-1]
            else:
              control_input = None
            outputs.append(
                self.all_reduce(input_tensor, control_input, communication_hint,
                                timeout))

    return outputs

  def all_gather(self,
                 input_tensor,
                 axis,
                 communication_hint='AUTO',
                 timeout=0):
    """All-gather a dense tensor.

    This method must be called inside a tf.function.

    Args:
      input_tensor: a dense tensor. It must have the same rank on all replicas,
        and dimensions other than `axis` need to be the same as well.
      axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the
        range [0, rank(value)).
      communication_hint: string providing hint to runtime for choosing
        collective implementation. Available options are `AUTO`, `NCCL`, and
        `RING`.
      timeout: a float. The timeout in seconds.

    Returns:
      The gathered Tensor.

    Raises:
      RuntimeError: if called in eager mode.
      ValueError: if the group key or device is unknown to the
        `CollectiveKeys`.
    """
    if context.executing_eagerly():
      raise RuntimeError('all_gather in eager mode is not supported')

    # Two collectives are launched (one for shapes, one for data), so two
    # instance keys are drawn. They are allocated tensor-key first even though
    # the shape gather is built first; the order must match on every replica.
    instance_key_tensor = self._collective_keys.get_instance_key(
        self._group_key, self._device)
    instance_key_shape = self._collective_keys.get_instance_key(
        self._group_key, self._device)
    # NOTE(review): the control dependency on identity(input_tensor) presumably
    # ensures the input is computed before the collectives run — confirm.
    with ops.device(self._device), \
         ops.control_dependencies([array_ops.identity(input_tensor)]):
      # 1. Transpose
      # E.g. Given an input_tensor with shape [2,2,5,1] and axis to gather is 3,
      # we use perm_pre=[3 0 1 2] to transpose it to [1,2,2,5], which
      # brings the 3rd dim first; afterwards we use perm_after=[1,2,3,0] to
      # place it back.
      perm_pre = array_ops.concat(
          ([axis], math_ops.range(axis),
           math_ops.range(axis + 1, array_ops.rank(input_tensor))),
          axis=0)
      input_tensor_t = array_ops.transpose(input_tensor, perm=perm_pre)
      # 2. Pad
      # Gather every replica's (transposed) shape, then pad the local leading
      # dimension up to the largest one so all replicas gather equal sizes.
      gathered_shape = collective_ops.all_gather(
          array_ops.expand_dims_v2(array_ops.shape_v2(input_tensor_t), axis=0),
          self._group_size,
          self._group_key,
          instance_key_shape,
          communication_hint,
          timeout=timeout)
      first_dims = gathered_shape[:, 0]
      full_axis_dim = math_ops.reduce_max(first_dims)
      padded_input_tensor = _pad_util(input_tensor_t, full_axis_dim)

      # 3. Gather
      gather_padded_out_tensor = collective_ops.all_gather(
          padded_input_tensor,
          self._group_size,
          self._group_key,
          instance_key_tensor,
          communication_hint,
          timeout=timeout)
      # 4. Unpad
      # Replica i's data occupies [i * full_axis_dim, i * full_axis_dim +
      # first_dims[i]) of the gathered result; slice out just those rows.
      # This loop relies on first_dims having a static leading size.
      split_tensors = []
      for i in range(first_dims.shape[0]):
        start_pos = i * full_axis_dim
        split_tensors.append(gather_padded_out_tensor[start_pos:start_pos +
                                                      first_dims[i]])
      out_tensor_t = array_ops.concat(split_tensors, 0)

      # 5. Transpose back
      perm_after = array_ops.concat(
          (math_ops.range(1, axis + 1), [0],
           math_ops.range(axis + 1, array_ops.rank(input_tensor_t))),
          axis=0)
      return array_ops.transpose(out_tensor_t, perm=perm_after)

  def all_reduce_indexed_slices(self,
                                input_slices,
                                communication_hint='AUTO',
                                timeout=0):
    """All-reduce an IndexedSlices.

    This method must be called inside a tf.function.

    Args:
      input_slices: an IndexedSlices.
      communication_hint: string providing hint to runtime for choosing
        collective implementation.
      timeout: a float. The timeout in seconds.

    Returns:
      The reduced IndexedSlices.

    Raises:
      RuntimeError: if called in eager mode.
      ValueError: if the group key or device is unknown to the
        `CollectiveKeys`.
    """
    if context.executing_eagerly():
      raise RuntimeError(
          'all_reduce_indexed_slices in eager mode is not supported')

    # All four instance keys are allocated up front, outside the cond branches
    # below, so every replica consumes the same keys whichever branch runs.
    gather_length_key = self._collective_keys.get_instance_key(
        self._group_key, self._device)
    gather_indices_key = self._collective_keys.get_instance_key(
        self._group_key, self._device)
    gather_values_key = self._collective_keys.get_instance_key(
        self._group_key, self._device)
    reduce_densified_key = self._collective_keys.get_instance_key(
        self._group_key, self._device)

    # Current CollectiveAllGather implementations require input IndexedSlices to
    # have consistent length across the board, we handle the reduction of
    # IndexedSlices as follows:
    #   1. Gather the lengths of IndexedSlices from all participants.
    #   2. If they have consistent length, apply all_gather.
    #   3. Otherwise convert IndexedSlices to dense tensors and apply
    #      all_reduce.
    with ops.device(self._device):

      def all_gather():
        """Use all_gather to aggregate `IndexedSlices`."""
        all_values = collective_ops.all_gather(
            input_slices.values,
            self._group_size,
            self._group_key,
            gather_values_key,
            communication_hint,
            timeout=timeout)
        # Add control dependency to order the all-gather.
        control = [all_values] if communication_hint == 'NCCL' else []
        with ops.control_dependencies(control):
          all_indices = collective_ops.all_gather(
              input_slices.indices,
              self._group_size,
              self._group_key,
              gather_indices_key,
              communication_hint,
              timeout=timeout)
        return ops.IndexedSlices(
            values=all_values,
            indices=all_indices,
            dense_shape=input_slices.dense_shape)

      def densify_and_all_reduce():
        """Use all_reduce to aggregate `IndexedSlices`."""
        densified = ops.convert_to_tensor(input_slices)
        reduced = collective_ops.all_reduce(
            densified,
            self._group_size,
            self._group_key,
            reduce_densified_key,
            'Add',
            'Id', [0],
            communication_hint,
            timeout=timeout)
        # We have to convert dense grad to IndexedSlice because all_reduce()
        # and all_gather() must have the same return type as required by
        # control_flow_ops.cond.
        return ops.IndexedSlices(
            values=reduced,
            indices=math_ops.range(array_ops.shape(reduced)[0]),
            dense_shape=input_slices.dense_shape)

      length = array_ops.shape(input_slices.indices)
      all_lengths = collective_ops.all_gather(
          length,
          self._group_size,
          self._group_key,
          gather_length_key,
          communication_hint,
          timeout=timeout)
      # Equal max and min means every participant has the same length.
      return control_flow_ops.cond(
          math_ops.equal(
              math_ops.reduce_max(all_lengths),
              math_ops.reduce_min(all_lengths)), all_gather,
          densify_and_all_reduce)


def aggregate_tensors_or_indexed_slices(values, accumulation_fn=math_ops.add_n):
  """Aggregates `values`, dispatching on whether any of them is sparse.

  Args:
    values: a list of tensors and/or `IndexedSlices`.
    accumulation_fn: the function used to combine `values` when none of them
      is an `IndexedSlices`.

  Returns:
    The result of `backprop.aggregate_indexed_slices_gradients(values)` if any
    element is an `IndexedSlices`, otherwise `accumulation_fn(values)`.
  """
  has_sparse = any(isinstance(v, ops.IndexedSlices) for v in values)
  if not has_sparse:
    return accumulation_fn(values)
  return backprop.aggregate_indexed_slices_gradients(values)


def divide_by_n_tensors_or_indexed_slices(value, n):
  """Divides a tensor, or the values of an `IndexedSlices`, by `n`.

  An `IndexedSlices` is first passed through
  `backprop.flatten_nested_indexed_slices`; only its values are divided, and
  its indices and dense shape are kept.

  Args:
    value: a tensor or an `ops.IndexedSlices`.
    n: the divisor.

  Returns:
    `value / n` for a tensor, or a new `IndexedSlices` with divided values.
  """
  if not isinstance(value, ops.IndexedSlices):
    return value / n
  flat = backprop.flatten_nested_indexed_slices(value)
  return ops.IndexedSlices(flat.values / n, flat.indices, flat.dense_shape)


def copy_tensor_or_indexed_slices_to_device(value, device):
  """Copies a tensor or an `IndexedSlices` onto `device` via identity ops.

  Args:
    value: a tensor or an `ops.IndexedSlices`.
    device: the device to place the copy on.

  Returns:
    A tensor, or an `IndexedSlices` when `value` is one, whose components are
    identity ops created under `ops.device(device)`.
  """
  with ops.device(device):
    if isinstance(value, ops.IndexedSlices):
      copied_values = array_ops.identity(value.values)
      copied_indices = array_ops.identity(value.indices)
      # `dense_shape` is optional on IndexedSlices and identity() cannot take
      # None, so pass a missing shape through unchanged.
      if value.dense_shape is not None:
        copied_shape = array_ops.identity(value.dense_shape)
      else:
        copied_shape = None
      result = ops.IndexedSlices(copied_values, copied_indices, copied_shape)
    else:
      result = array_ops.identity(value)
  return result


def is_indexed_slices(value):
  """Returns whether `value` is, or consists entirely of, `IndexedSlices`.

  Args:
    value: an `ops.IndexedSlices`, a `value_lib.DistributedValues`, or any
      other value such as a plain dense tensor.

  Returns:
    True if `value` is an `IndexedSlices`, or a `DistributedValues` whose
    components are all `IndexedSlices`; False otherwise.
  """
  if isinstance(value, ops.IndexedSlices):
    return True
  if isinstance(value, value_lib.DistributedValues):
    return all(isinstance(v, ops.IndexedSlices) for v in value.values)
  # Anything else (e.g. a plain dense tensor) is dense. This used to be an
  # `assert`, which crashed on such values and disappears under `python -O`.
  return False


def split_by_sparsity(values):
  """Partitions `values` into dense and sparse (IndexedSlices) entries.

  Args:
    values: a list of tensors or `PerReplica`s.

  Returns:
    A 4-tuple (dense_values, dense_indices, sparse_values, sparse_indices),
    where each `*_indices` list records the positions in `values` that the
    matching `*_values` entries came from.
  """
  dense_values, dense_indices = [], []
  sparse_values, sparse_indices = [], []
  for position, value in enumerate(values):
    if is_indexed_slices(value):
      target_values, target_indices = sparse_values, sparse_indices
    else:
      target_values, target_indices = dense_values, dense_indices
    target_values.append(value)
    target_indices.append(position)
  return dense_values, dense_indices, sparse_values, sparse_indices


def stitch_values(values_and_indices_list):
  """Stitch values together according to their indices.

  Args:
    values_and_indices_list: a list of (values, indices) tuples indicating
      the values and their positions in the returned list. Entries that are
      empty, or whose values are empty, contribute nothing.

  Returns:
    a stitched list of values. Its length is the total number of values; any
    position that no entry writes to stays None.

  Raises:
    ValueError: if two values are assigned to the same position.
  """
  # Filter once so the sizing pass and the placement pass agree on which
  # entries contribute (an empty tuple has no `[0]` to take the length of).
  contributing = [pair for pair in values_and_indices_list if pair and pair[0]]
  result = [None] * sum(len(pair[0]) for pair in contributing)
  for pair in contributing:
    values, indices = pair[0], pair[1]
    for value, index in zip(values, indices):
      # Explicit check rather than `assert`, which `python -O` strips.
      if result[index] is not None:
        raise ValueError(
            'Position {} is assigned more than once.'.format(index))
      result[index] = value
  return result


def group_by_size(input_tensors, bytes_per_pack):
  """Splits `input_tensors` into order-preserving packs of roughly `bytes_per_pack`.

  Grouping is best effort: a new pack is started only once the current pack
  has grown past `bytes_per_pack`. This errs toward fewer, larger packs, so a
  pack may end up bigger or smaller than `bytes_per_pack`. Only tensors with
  a fully known shape can be sized.

  Args:
    input_tensors: a list of Tensor.
    bytes_per_pack: an integer.

  Returns:
    A list of packs (lists) of Tensor. Everything goes into a single pack
    when `bytes_per_pack` is zero or any tensor's element count is unknown.
  """
  if bytes_per_pack == 0:
    return [input_tensors]

  packs = []
  current_bytes = 0
  for tensor in input_tensors:
    element_count = tensor.shape.num_elements()
    if element_count is None:
      # Can't pack values with unknown shape.
      logging.warning(
          'not packing values due to the unknown or inconsistent shape of %s',
          tensor)
      return [input_tensors]
    tensor_bytes = element_count * tensor.dtype.size
    needs_new_pack = not packs or current_bytes > bytes_per_pack
    if needs_new_pack:
      packs.append([])
      current_bytes = 0
    packs[-1].append(tensor)
    current_bytes += tensor_bytes
  return packs


def _pad_util(input_tensor, full_axis_dim):
  """Pads the end of `input_tensor`'s first dimension up to `full_axis_dim`.

  Only the leading axis is padded, and only at its end; every other axis gets
  a [0, 0] padding. `array_ops.pad` is called with its default mode and
  constant. NOTE(review): assumes `full_axis_dim` is at least the current
  leading size — confirm with callers.

  Args:
    input_tensor: the tensor to pad.
    full_axis_dim: the target size of the first dimension.

  Returns:
    The padded tensor.
  """
  pad_amount = full_axis_dim - array_ops.shape_v2(input_tensor)[0]
  rank = array_ops.rank(input_tensor)
  other_axes_paddings = array_ops.zeros(
      shape=(rank - 1, 2), dtype=dtypes.int32)
  paddings = array_ops.concat([[[0, pad_amount]], other_axes_paddings], axis=0)
  return array_ops.pad(input_tensor, paddings)
